package com.pan.insist.filter;

import com.pan.insist.config.WebConfig;
import com.pan.insist.constant.CommonConstant;
import com.pan.insist.util.SpringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.DependsOn;
import org.springframework.core.annotation.Order;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Pattern;


/**
 * Servlet filter providing CSRF protection.
 *
 * <p>The check is only enforced when {@code WebConfig#getEnvironment()} equals
 * {@code CommonConstant.Environment.ONLINE}; in every other environment requests pass straight
 * through. When enforced, a request continues down the chain only if
 * <ul>
 *   <li>its path is whitelisted in {@link #ALLOWED_PATHS}, or</li>
 *   <li>its {@code Referer} header points into {@code webConfig.getSystemUrl()} and its
 *       {@code Host} header equals {@code webConfig.getHost()}.</li>
 * </ul>
 * Any other request is redirected to the system URL.
 *
 * <p>NOTE(review): Spring Boot may not honor {@code @Order} on filters registered through
 * {@code @WebFilter}; if ordering relative to other filters matters, confirm it or register this
 * filter with a {@code FilterRegistrationBean} instead.
 *
 * @author kaiji
 */
@Order(2)
@DependsOn("springUtil")
@WebFilter(filterName = "csrfFilter", urlPatterns = "/*")
public class CsrfFilter implements Filter {
    private static final Logger logger = LoggerFactory.getLogger(CsrfFilter.class);

    /**
     * Paths excluded from the CSRF check. An entry ending in {@code /**} matches that directory
     * and everything below it; any other entry must match exactly. Paths are compared after
     * trailing slashes are stripped, so the root path is matched by the {@code ""} entry.
     */
    private static final Set<String> ALLOWED_PATHS = Collections.unmodifiableSet(new HashSet<>(
            Arrays.asList("", "/", "/login/**", "/static/**")));

    /** Suffix marking a whitelist entry as a directory wildcard. */
    private static final String WILDCARD_SUFFIX = "/**";

    /** One or more trailing slashes; compiled once instead of on every request. */
    private static final Pattern TRAILING_SLASHES = Pattern.compile("/+$");

    @Override
    public void init(FilterConfig filterConfig) {
        // Nothing to initialise: WebConfig is looked up from the Spring context on each request.
    }

    @Override
    public void doFilter(ServletRequest servletRequest,
                         ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        WebConfig webConfig = loadWebConfig();
        logger.debug("当前环境是：{}", webConfig.getEnvironment());

        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        logger.debug("访问路径为：{}", request.getRequestURI());

        // CSRF protection is only enforced in the online environment.
        if (!CommonConstant.Environment.ONLINE.equals(webConfig.getEnvironment())) {
            filterChain.doFilter(request, response);
            return;
        }

        // Whitelisted paths skip the check entirely.
        if (isAllowedPath(resolvePath(request))) {
            filterChain.doFilter(request, response);
            return;
        }

        String host = request.getHeader("host");
        String referer = request.getHeader("Referer");
        String validateHost = webConfig.getHost();
        // Only a Referer inside the system URL plus the expected Host may proceed; a missing or
        // foreign Referer, or an unexpected Host, sends the user back to the home page.
        if (isTrustedReferer(referer, webConfig.getSystemUrl())
                && validateHost != null && validateHost.equals(host)) {
            filterChain.doFilter(request, response);
        } else {
            redirectReferer(request, response, host, referer, webConfig);
        }
    }

    /**
     * Fetches the {@code WebConfig} bean from the Spring context.
     *
     * <p>If the lookup throws, a fallback configuration whose environment is
     * {@code CommonConstant.Environment.LOCAL} is returned, which means the CSRF check is skipped
     * for this request. NOTE(review): this fails open; confirm that is acceptable in production.
     *
     * @return the configured bean, or a LOCAL-environment fallback on failure
     */
    private static WebConfig loadWebConfig() {
        try {
            return SpringUtil.getBean(WebConfig.class);
        } catch (Exception e) {
            // Pass the exception itself as the last argument so SLF4J logs the stack trace too.
            logger.error("获取webConfig 异常!原因{}", e.getMessage(), e);
            WebConfig fallback = new WebConfig();
            fallback.setEnvironment(CommonConstant.Environment.LOCAL);
            return fallback;
        }
    }

    /**
     * Returns the request path within the application, with trailing slashes removed.
     *
     * <p>Built from the servlet path and path info, which the container supplies decoded, rather
     * than from the raw request URI. Matching the raw URI would let a request such as
     * {@code /static/..;/admin} look like a {@code /static} path even though a container that
     * strips path parameters and normalizes dot segments may route it to {@code /admin}.
     *
     * @param request the current request
     * @return the context-relative path, never {@code null}
     */
    private static String resolvePath(HttpServletRequest request) {
        String servletPath = request.getServletPath();
        String pathInfo = request.getPathInfo();
        String path = (servletPath == null ? "" : servletPath) + (pathInfo == null ? "" : pathInfo);
        return TRAILING_SLASHES.matcher(path).replaceAll("");
    }

    /**
     * Returns whether the Referer header points into the configured system URL.
     *
     * <p>A bare prefix test would accept look-alike hosts such as
     * {@code https://example.com.evil.net/} when the system URL is {@code https://example.com}.
     * So unless the system URL already ends with {@code /}, the character right after the prefix
     * must end the URL or start a path, query, or fragment.
     *
     * @param referer   the raw {@code Referer} header, possibly {@code null}
     * @param systemUrl the configured system URL, possibly {@code null}
     * @return {@code true} only for a non-null Referer inside a non-empty system URL
     */
    private static boolean isTrustedReferer(String referer, String systemUrl) {
        if (referer == null || systemUrl == null || systemUrl.isEmpty()) {
            return false;
        }
        String candidate = referer.trim();
        if (!candidate.startsWith(systemUrl)) {
            return false;
        }
        if (systemUrl.endsWith("/") || candidate.length() == systemUrl.length()) {
            return true;
        }
        char next = candidate.charAt(systemUrl.length());
        return next == '/' || next == '?' || next == '#';
    }

    /**
     * Logs the rejected request and redirects the client to the configured system URL.
     *
     * @param request   the rejected request
     * @param response  the response used for the redirect
     * @param host      the request's {@code Host} header, logged only
     * @param referer   the request's {@code Referer} header, logged only
     * @param webConfig supplies the redirect target
     * @throws IOException if the redirect cannot be sent
     */
    private void redirectReferer(HttpServletRequest request, HttpServletResponse response,
                                 String host, String referer, WebConfig webConfig) throws IOException {
        logger.info("====================");
        logger.info("当前请求:{}", request.getRequestURL());
        logger.info("当前host:{}", host);
        logger.info("当前referer:{}", referer);
        logger.info("====================");
        // sendRedirect sets the Location header itself; a "host" header has no meaning on a
        // response, so neither is set by hand.
        response.sendRedirect(webConfig.getSystemUrl());
    }

    /**
     * Returns whether {@code path} is covered by {@link #ALLOWED_PATHS}.
     *
     * <p>A {@code /**} entry matches its directory exactly or anything below it (so
     * {@code /login/**} covers {@code /login} and {@code /login/x} but not {@code /loginAdmin});
     * any other entry must equal the path.
     *
     * @param path context-relative path with trailing slashes already stripped
     * @return {@code true} if the path is whitelisted
     */
    private boolean isAllowedPath(String path) {
        for (String allowedPath : ALLOWED_PATHS) {
            if (allowedPath.endsWith(WILDCARD_SUFFIX)) {
                String prefix = allowedPath.substring(0, allowedPath.length() - WILDCARD_SUFFIX.length());
                if (path.equals(prefix) || path.startsWith(prefix + "/")) {
                    return true;
                }
            } else if (path.equals(allowedPath)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void destroy() {
        // No resources to release.
    }

}
